import os
import sys
sys.path.extend(["../"]) # pylint: disable=wrong-import-position

import pickle
from collections import defaultdict

from utils import load_mlp_conf_with_default, load_sdmp_conf_with_default
from data_utils import load_data, SDMPDataPre
from graph_dict import SDMP
from model import MLP

# def save_listlist_csv(foo, path):
#     with open(path, "w") as fout:
        
        
def load_SDMP(result_root_path,
              TARGET_GNN_FOLDER,
              DATA_ROOT_FOLDER,
              device="cpu"):
    """
    Load a trained SDMP model together with its data preprocessor.

    Args:
        result_root_path: folder of a finished SDMP run. It is expected to contain
            ``dict_conf.yml`` (SDMP config), ``conf.yml`` (MLP/train config) and
            ``GNN_data_split_seed.txt`` (data split seed). It is also passed to ``SDMP.load``
            with ``log_name="dict_log.pkl"``.
        TARGET_GNN_FOLDER: folder that the ``target_h_model_path``,
            ``target_h_model_conf_path`` and ``target_h_model_metric_path`` entries of the SDMP
            config are joined onto.
        DATA_ROOT_FOLDER: root folder for per-dataset data. ``<DATA_ROOT_FOLDER>/<name>`` is
            created if missing, and the preprocessor cache path is ``SDMPPre`` inside it.
        device: device passed to the preprocessor and to the SDMP model.

    Returns:
        Tuple ``(preprocesser, sdmp, g, sdmp_conf, train_conf)``.
    """
    import ast  # local import: only needed to parse the seed file safely

    SDMP_CONF_PATH = os.path.join(result_root_path, "dict_conf.yml")
    MLP_CONF_PATH = os.path.join(result_root_path, "conf.yml")

    # loading data
    train_conf = load_mlp_conf_with_default(MLP_CONF_PATH)
    sdmp_conf = load_sdmp_conf_with_default(SDMP_CONF_PATH)
    print(train_conf)
    seed_path = os.path.join(result_root_path, "GNN_data_split_seed.txt")
    with open(seed_path, 'r', encoding="utf-8") as fin:
        seed_str = fin.read()
    # ast.literal_eval accepts only Python literals; eval() would execute any expression
    # written into the seed file.
    g = load_data(train_conf["name"], seed=int(ast.literal_eval(seed_str.strip())))

    GNN_MODEL_PATH = os.path.join(TARGET_GNN_FOLDER, sdmp_conf['target_h_model_path'])
    GNN_CONF_PATH = os.path.join(TARGET_GNN_FOLDER, sdmp_conf['target_h_model_conf_path'])
    GNN_ACC_PATH = os.path.join(TARGET_GNN_FOLDER, sdmp_conf['target_h_model_metric_path'])

    DATA_FOLDER = os.path.join(DATA_ROOT_FOLDER, train_conf["name"])
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(DATA_FOLDER, exist_ok=True)

    preprocesser = SDMPDataPre(train_conf["name"], sdmp_conf["feature_normalize"],
                               sdmp_conf["target_h_mode"],
                               GNN_CONF_PATH, GNN_MODEL_PATH, sdmp_conf["target_h_model"],
                               sdmp_conf["h_init_theta_mode"], sdmp_conf["h_init_theta_k"],
                               sdmp_conf["h_init_theta_k_fanout"],
                               sdmp_conf["theta_cand_mode"], sdmp_conf["theta_cand_k2"],
                               sdmp_conf["theta_cand_k1"], sdmp_conf["theta_cand_fanout"],
                               sdmp_conf["theta_cand_add_self"],
                               sdmp_conf,
                               use_cache=True, cache_path=os.path.join(DATA_FOLDER, "SDMPPre"),
                               device=device)
    # The graph loaded above (with the recorded split seed) replaces the preprocessor's one.
    preprocesser.g = g
    preprocesser.disp_states()
    theta_cand = preprocesser.theta_cand
    h_init_theta = preprocesser.h_init_theta
    X = preprocesser.X
    target = preprocesser.target

    # construct the SDMP
    print("Constructing the SDMP model...")
    sdmp = SDMP(X,
                target,
                theta_cand,
                h_init_theta,
                sdmp_conf,
                device=device,
                verbose=True)
    sdmp.load(result_root_path, log_name="dict_log.pkl")
    return preprocesser, sdmp, g, sdmp_conf, train_conf

def load_SDMP_MLP(train_conf, sdmp, g, result_root_path, device="cpu"):
    """
    Rebuild the MLP trained on top of an SDMP model and restore its weights.

    Args:
        train_conf: MLP training config; reads ``hidden_size``, ``hidden_layer`` and ``dropout``.
        sdmp: loaded SDMP model. The MLP input width is the ``out_features`` of the last entry
            of the first child module of ``sdmp.h``.
        g: graph/dataset object; ``g.num_classes`` gives the MLP output width.
        result_root_path: run folder; weights are read from ``models/state_dict_0`` inside it.
        device: device the MLP is moved to.

    Returns:
        The ``MLP`` instance with the stored state dict loaded.
    """
    print("Constructing the MLP modeling...")
    state_path = os.path.join(result_root_path, "models", "state_dict_0")

    first_child = list(sdmp.h.children())[0]
    input_width = first_child[-1].out_features
    hidden_widths = [train_conf['hidden_size']] * train_conf['hidden_layer']
    n_classes = g.num_classes

    mlp = MLP(input_width, hidden_widths, n_classes, dropout=train_conf['dropout'])
    mlp = mlp.to(device)

    # NOTE(review): pickle.load can execute arbitrary code; only load state files you trust.
    with open(state_path, 'rb') as state_file:
        saved_state = pickle.load(state_file)
    mlp.load_state_dict(saved_state)
    return mlp

class confHolder:
    """
    Hold the SDMP and MLP configurations of several (name, model) runs.

    Args:
        path_dict: mapping from a ``(name, model)`` tuple to a dict holding at least
            ``"result_root_path"``, the run folder containing ``dict_conf.yml`` and ``conf.yml``.
        name_model_pool: iterable of ``(name, model)`` pairs to load.

    After construction ``self.sdmp_conf`` and ``self.mlp_conf`` map each ``(name, model)``
    tuple to its loaded SDMP and MLP configuration respectively.
    """
    def __init__(self, path_dict, name_model_pool):
        self.path_dict = path_dict
        self.name_model_pool = name_model_pool

        self.sdmp_conf, self.mlp_conf = self.load_all()

    def load_all(self):
        """
        Load both config files for every pair in ``name_model_pool``.

        Returns:
            ``(sdmp_confs, mlp_confs)``: two dicts keyed by ``(name, model)`` tuples.
        """
        sdmp_confs = {}
        mlp_confs = {}
        for name, model in self.name_model_pool:
            key = (name, model)
            run_dir = self.path_dict[key]["result_root_path"]
            sdmp_path = os.path.join(run_dir, "dict_conf.yml")
            mlp_path = os.path.join(run_dir, "conf.yml")

            # MLP config first, then SDMP config (same loading order as before).
            loaded_mlp = load_mlp_conf_with_default(mlp_path)
            loaded_sdmp = load_sdmp_conf_with_default(sdmp_path)

            sdmp_confs[key] = loaded_sdmp
            mlp_confs[key] = loaded_mlp

        return sdmp_confs, mlp_confs

    def gen_latex_table(self, all_conf, keys_to_show):
        """
        Render selected config entries as LaTeX table rows.

        Each row starts with ``"<name>/<model>"`` followed by the value of every key in
        ``keys_to_show``, separated by ``" & "``. List values are shown comma-separated.
        Rows are joined by a LaTeX line break followed by a space and a newline.

        Args:
            all_conf: dict keyed by ``(name, model)`` string tuples, e.g. ``self.sdmp_conf``.
            keys_to_show: config keys to show as columns, in order.
        """
        def render_cell(value):
            if isinstance(value, list):
                return ", ".join(str(item) for item in value)
            return str(value)

        lines = []
        for run_key, conf in all_conf.items():
            cells = [run_key[0] + "/" + run_key[1]]
            cells.extend(render_cell(conf[column]) for column in keys_to_show)
            lines.append(" & ".join(cells))

        return "\\\\ \n".join(lines)

class latexTableBase:
    """
    Generate a LaTeX table string from analyzeLogBase-formatted logs.

    It takes a list of log dicts (e.g. the ``all_log`` member of analyzeLogBase) and a header
    list, and renders a LaTeX table. Subclasses must implement ``gen_a_row`` to turn one log
    dict into one row (a list of cells). ``post_precessing`` is an optional hook, usually used
    to rank or reorder ``self.table_element`` before rendering.

    Input:
    list_log: list of logs, e.g. the all_log member from analyzeLogBase.
    header: list of column titles; may be empty, in which case no header row is emitted and
        the column count is taken from the widest body row.

    Attributes set by the constructor:
    table_element: list of rows after post_precessing.
    table_str_core: the body rows rendered as LaTeX.
    header_table: the header row rendered as LaTeX.
    table_str: a complete, ready-to-paste ``table`` environment.

    The class attributes ``caption``, ``label`` and ``col_align`` may be overridden by
    subclasses to customise the generated environment.
    """
    # Defaults reproduce the previously hard-coded table settings.
    caption = "My Table"
    label = "tab:my_tab"
    col_align = "l"

    def __init__(self, list_log, header):
        self.list_log = list_log
        self.header = header

        self.table_element = []
        self.table_str_core = self.wrap_table()
        # `or []` lets a None header behave like an empty one instead of crashing.
        self.header_table = self.gen_table([self.header or []])
        self.table_str = self.wrap_full_table()

    def wrap_full_table(self):
        """
        Generate the ready-to-paste LaTeX table with the predefined style.

        The column count comes from the header. If the header is empty, it falls back to the
        widest body row so the ``tabular`` column spec is never ``{}``, which LaTeX rejects.
        """
        if self.header:
            n_cols = len(self.header)
        else:
            n_cols = max((len(row) for row in self.table_element), default=0)
        prefix = ("\\begin{table}[h]\n\t\\centering\n\t\\caption{" + self.caption + "}\n"
                  + "\t\\begin{tabular}{" + self.col_align * n_cols + "}\n\t\t\\hline\n")
        # Skip the header line entirely when there is no header, rather than emitting an
        # empty row.
        header_line = (self.header_table + "\\\\\\hline\n") if self.header else ""
        mid = header_line + self.table_str_core + "\\\\\\hline\n"
        post = "\t\\end{tabular}\n\t\\label{" + self.label + "}\n\\end{table}"
        return prefix + mid + post

    def wrap_table(self):
        """
        Build ``self.table_element`` from the logs, post-process it, and render the body rows.
        """
        for this_log in self.list_log:
            this_row = self.gen_a_row(this_log)
            self.table_element.append(this_row)
        self.post_precessing()
        return self.gen_table(self.table_element)

    def gen_a_row(self, dict_log):
        """Turn one log dict into a list of cells; must be implemented by subclasses."""
        raise NotImplementedError("The method gen_a_row is not implemented! This should be "
                                  "specifically implmented for each table. ")

    def post_precessing(self):
        """
        Optional hook to modify ``self.table_element`` before rendering.

        The (misspelled) name is kept because subclasses override it under this name.
        """
        print("No post precessing is executed. ")

    @staticmethod
    def gen_table(rows, row_prefix="\t\t"):
        """
        Generate a LaTeX table body from a list of rows.

        Cells are converted with ``str()`` and joined by ``" & "``, and each row gets
        ``row_prefix``. Rows are joined by a LaTeX line break plus a newline, with no trailing
        break.
        """
        res = [row_prefix + " & ".join(str(cell) for cell in this_row) for this_row in rows]
        return "\\\\\n".join(res)
